# Custom TensorFlow operations for HDRnet, written in C++ and CUDA.

load("//tensorflow:tensorflow.bzl", "tf_gen_op_wrapper_py", "tf_kernel_library")

# Package-wide defaults. License: Apache 2.0.
package(licenses = ["notice"])

# Header-only target: exposes numerics.h and compiles no sources of its own.
cc_library(
    name = "numerics",
    hdrs = [
        "numerics.h",
    ],
)

# C++ library for bilateral_slice_apply: one source file plus its public
# header. The TF kernel target for the same op depends on this library.
cc_library(
    name = "bilateral_slice_apply",
    srcs = [
        "bilateral_slice_apply.cc",
    ],
    hdrs = [
        "bilateral_slice_apply.h",
    ],
    deps = [
        ":numerics",
        "//array",
    ],
)

# TensorFlow kernels for the fused bilateral_slice_apply op and its gradient.
# The op source is listed in `srcs`; the GPU (.cu.cc) source goes in
# `gpu_srcs`.
tf_kernel_library(
    name = "bilateral_slice_apply_tf_kernel",
    srcs = ["bilateral_slice_apply_op.cc"],
    gpu_srcs = ["bilateral_slice_apply.cu.cc"],
    deps = [
        ":bilateral_slice_apply",
        ":numerics",
        "//array",
        "//eigen3",
        "//tensorflow/core:framework",
        # NOTE(review): ":bilateral_slice_tf_kernel" lists
        # "//tensorflow/core:lib" in this position instead -- confirm the
        # difference between the two kernel targets is intentional.
        "//tensorflow/core:framework_lite",
    ],
)

# Generates gen_bilateral_slice_apply_ops.py, the Python wrapper that exposes
# ":bilateral_slice_apply_tf_kernel" as a TF op.
# To link correctly, a Python rule using the wrapper must depend on both this
# target and ":bilateral_slice_apply_tf_kernel".
tf_gen_op_wrapper_py(
    name = "bilateral_slice_apply_py_tf_op",
    out = "gen_bilateral_slice_apply_ops.py",
    deps = [
        ":bilateral_slice_apply_tf_kernel",
    ],
)

# C++ library for bilateral_slice: one source file plus its public header.
# The TF kernel target for the same op depends on this library.
cc_library(
    name = "bilateral_slice",
    srcs = [
        "bilateral_slice.cc",
    ],
    hdrs = [
        "bilateral_slice.h",
    ],
    deps = [
        ":numerics",
        "//array",
    ],
)

# TensorFlow kernels for the bilateral_slice op and its gradient.
# The op source is listed in `srcs`; the GPU (.cu.cc) source goes in
# `gpu_srcs`.
tf_kernel_library(
    name = "bilateral_slice_tf_kernel",
    srcs = ["bilateral_slice_op.cc"],
    gpu_srcs = ["bilateral_slice.cu.cc"],
    deps = [
        ":bilateral_slice",
        ":numerics",
        "//array",
        "//eigen3",
        "//tensorflow/core:framework",
        # NOTE(review): ":bilateral_slice_apply_tf_kernel" lists
        # "//tensorflow/core:framework_lite" in this position instead --
        # confirm the difference between the two kernel targets is intentional.
        "//tensorflow/core:lib",
    ],
)

# Generates gen_bilateral_slice_ops.py, the Python wrapper that exposes
# ":bilateral_slice_tf_kernel" as a TF op.
# To link correctly, a Python rule using the wrapper must depend on both this
# target and ":bilateral_slice_tf_kernel".
tf_gen_op_wrapper_py(
    name = "bilateral_slice_py_tf_op",
    out = "gen_bilateral_slice_ops.py",
    deps = [
        ":bilateral_slice_tf_kernel",
    ],
)
